Tag

Showing blog posts sorted under the tag: neural network

Interactive Digit Classification Using Neural Network Trained on MNIST Data

Several years ago, I created a fully connected neural network from scratch in C as a learning exercise. I followed the first few chapters of Michael Nielsen's book 'Neural Networks and Deep Learning', which I highly recommend.

The network was designed to train on the MNIST dataset of handwritten digits, a well-known dataset used in many machine learning examples. The goal is to identify each handwritten digit as a number from 0 to 9. The final network performed quite well and achieved 97.14% accuracy on the test dataset. Not bad for a bit of matrix algebra wrapped up in some C code.

Anyways, ever since then I've had the idea to create a little browser widget to let people use the model I trained in an interactive way. Of course, I was beaten to the punch once, twice, and many more times I'm sure. But even still I wanted to see how well my model would perform at this task.

Homemade Fully Connected Neural Network

Before starting to work on the widget, I beefed up my neural network a little bit and was able to train one with 98.2% accuracy on the MNIST test data. I then used a web framework called Svelte to create a drawing and predicting widget. Since my model is all simple linear algebra, exporting the weights from C and hard-coding them into JavaScript was not too much work. Libraries like Math.js made it pretty easy to recreate everything. The final product is the widget you see below. It runs entirely client-side in the browser using my trained neural network.
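To make the "simple linear algebra" concrete, here is a minimal sketch of a forward pass, written in Python rather than the JavaScript used in the widget. It is not the widget's code: the sigmoid activation, the function names, and the toy weights are all assumptions for illustration. Each layer multiplies its input by a weight matrix, adds a bias, and applies the activation; the prediction is the index of the largest output.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def feedforward(layers, x):
    """Run x through every layer.

    layers is a list of (weights, biases) pairs, where weights holds
    one row of input weights per output neuron.
    """
    a = x
    for weights, biases in layers:
        a = [sigmoid(sum(w * v for w, v in zip(row, a)) + b)
             for row, b in zip(weights, biases)]
    return a

def predict(layers, x):
    """Return the index of the most active output neuron."""
    out = feedforward(layers, x)
    return max(range(len(out)), key=out.__getitem__)

# Hypothetical toy network: 2 inputs -> 2 hidden neurons -> 2 outputs.
toy_layers = [
    ([[1.0, -1.0], [-1.0, 1.0]], [0.0, 0.0]),
    ([[2.0, -2.0], [-2.0, 2.0]], [0.0, 0.0]),
]

print(predict(toy_layers, [1.0, 0.0]))  # 0
print(predict(toy_layers, [0.0, 1.0]))  # 1
```

For MNIST, the input would be the 784 pixel values of a flattened 28x28 image and the output would have 10 entries, one per digit.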

* does not run on iOS Safari, possibly macOS Safari as well *

If you tried a few numbers, you probably noticed that the predictions can often be rather poor. I found that it has a really difficult time with '1's, '0's, and '9's. It was a bit disappointing: even with 98.2% accuracy on test data, it still has a lot of trouble with new numbers. My guess is that, due to the fully-connectedness of the network, it has a difficult time generalizing to new data. For example, if a '1' is off to the side or at an angle that isn't present in the training data, then it will predict incorrectly.
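One way to see the intuition behind that guess is to look at what a fully connected layer actually receives: a flattened list of pixels. The sketch below uses a hypothetical vertical stroke as a stand-in for a '1' and measures the pixel overlap between a centered copy and a copy shifted three pixels to the right. The helper names are made up for this example.

```python
def vertical_stroke(size, column):
    """A size x size image with a single vertical line of 1.0 pixels."""
    return [[1.0 if c == column else 0.0 for c in range(size)]
            for _ in range(size)]

def flatten(image):
    return [pixel for row in image for pixel in row]

def overlap(a, b):
    """Dot product of two flattened images."""
    return sum(x * y for x, y in zip(a, b))

centered = flatten(vertical_stroke(28, 14))
shifted = flatten(vertical_stroke(28, 17))

print(overlap(centered, centered))  # 28.0
print(overlap(centered, shifted))   # 0.0
```

The two strokes look identical to a person, but they activate completely different input neurons, so weights learned for one position say nothing about the other.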

Keras/Tensorflow Convolutional Neural Network

Another type of network often used on the MNIST data is a convolutional network. I won't go into the topic here but this explanation was pretty helpful for my understanding. Convolutional networks work so well on MNIST that it's actually one of the 'getting started' examples for Keras.

I wanted to see how much the widget would improve with a convolutional network instead of my fully-connected version. So, I followed the Keras example and trained one in Python that reached an accuracy of 99.3% on the test data. Crucially though, I believe that it generalizes much better and is therefore more tolerant of digits that may not be presented in exactly the same way as in the training data. And the results definitely show it: in my testing, it seems to predict the correct digit much more often than my homemade model.
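The tolerance to position can be illustrated with a small sketch. This is a deliberate simplification and not the Keras model: it uses a single hand-picked 3x3 filter and a global max, whereas a real convolutional network learns many filters and typically pools over small windows. All names here are hypothetical.

```python
def conv2d(image, kernel):
    """'Valid' 2D convolution (cross-correlation, as in most ML libraries)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [[sum(image[r + i][c + j] * kernel[i][j]
                 for i in range(kh) for j in range(kw))
             for c in range(out_w)]
            for r in range(out_h)]

def global_max(feature_map):
    return max(max(row) for row in feature_map)

def vertical_stroke(size, column):
    return [[1.0 if c == column else 0.0 for c in range(size)]
            for _ in range(size)]

def horizontal_stroke(size, row):
    return [[1.0 if r == row else 0.0 for _ in range(size)]
            for r in range(size)]

# Hand-picked filter that responds to vertical lines.
vertical_filter = [
    [-1.0, 2.0, -1.0],
    [-1.0, 2.0, -1.0],
    [-1.0, 2.0, -1.0],
]

print(global_max(conv2d(vertical_stroke(28, 14), vertical_filter)))    # 6.0
print(global_max(conv2d(vertical_stroke(28, 17), vertical_filter)))    # 6.0
print(global_max(conv2d(horizontal_stroke(28, 14), vertical_filter)))  # 0.0
```

Because the same filter slides over the whole image, the stroke produces the same peak response wherever it is drawn, while a different shape does not.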

Again, the widget below is running entirely in the browser, this time using the TensorFlow.js library. TensorFlow.js allowed me to export the model from Python and import it directly into the Svelte widget.

* does not run on iOS Safari, possibly macOS Safari as well *

Embedding Widgets

Because the widgets run entirely client-side, feel free to embed them anywhere on your own site using the code snippets below. They are web components that use a shadow DOM, so they should always look the same no matter where they are embedded. Kind of like an iframe, but for the modern age.

<script src="https://www.cluoma.com/js/mnist_widget.js"></script>
<div><mnist-checker-widget /></div>
<script src="https://www.cluoma.com/js/mnist_convolution_widget.js"></script>
<div><mnist-convolution-checker-widget /></div>

Full source code for this project is posted on my GitHub.




Neural Network from Scratch


It’s been a while since my last machine learning project: implementing a decision tree in Julia. This time I wanted to take a closer look at neural networks. I was recently shown an amazing book, 'Neural Networks and Deep Learning' by Michael Nielsen. He does a great job distilling the basics to a point where his explanations become intuitive. I won't be able to explain anything as well as he does, so please check out his book.

The most basic neural networks are, as it turns out, surprisingly simple. It is possible to derive methods for building and training neural networks using only basic linear algebra and calculus. Neural networks have also been around for quite some time, but it wasn’t until backpropagation was suggested as a way of training networks in the 1970s that they really took off. Much of their complexity stems from the sheer size of the networks. Modern computer hardware and new scientific computing methods were required for neural networks to reach the popularity they have today.

Backpropagation is the key to training neural networks. Essentially, backpropagation takes the error at the output of a network and updates the weights within the network based on how much each contributed to that error. By calculating the error from a sample and adjusting the weights accordingly over many, many iterations, the network can be trained.
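The idea can be sketched in a few lines of plain Python. This is a minimal illustration and not the C implementation described in this post: it assumes sigmoid activations and a quadratic cost (as in the early chapters of the book), updates on one sample at a time, and uses made-up weights and function names.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(layers, x):
    """Return the activations of every layer, starting with the input."""
    activations = [x]
    for weights, biases in layers:
        prev = activations[-1]
        activations.append([
            sigmoid(sum(w * a for w, a in zip(row, prev)) + b)
            for row, b in zip(weights, biases)
        ])
    return activations

def cost(layers, x, y):
    """Quadratic cost for a single sample."""
    out = forward(layers, x)[-1]
    return 0.5 * sum((o - t) ** 2 for o, t in zip(out, y))

def backprop_step(layers, x, y, learning_rate):
    """One gradient-descent update on a single sample, in place."""
    activations = forward(layers, x)
    # Output error: (a - y) * sigmoid'(z), with sigmoid'(z) = a * (1 - a).
    delta = [(a - t) * a * (1.0 - a) for a, t in zip(activations[-1], y)]
    for index in range(len(layers) - 1, -1, -1):
        weights, biases = layers[index]
        prev = activations[index]  # input to this layer
        prev_delta = None
        if index > 0:
            # Push the error back through the weights (before updating them).
            prev_delta = [
                sum(weights[j][k] * delta[j] for j in range(len(delta)))
                * prev[k] * (1.0 - prev[k])
                for k in range(len(prev))
            ]
        # Each weight moves in proportion to its contribution to the error.
        for j, row in enumerate(weights):
            for k in range(len(row)):
                row[k] -= learning_rate * delta[j] * prev[k]
            biases[j] -= learning_rate * delta[j]
        if prev_delta is not None:
            delta = prev_delta

# Hypothetical toy network: 2 inputs -> 2 hidden neurons -> 1 output.
layers = [
    ([[0.5, -0.4], [0.3, 0.8]], [0.1, -0.2]),
    ([[0.7, -0.6]], [0.05]),
]
sample, target = [1.0, 0.0], [1.0]

before = cost(layers, sample, target)
for _ in range(200):
    backprop_step(layers, sample, target, learning_rate=0.5)
after = cost(layers, sample, target)
print(before, after)  # the cost should shrink as the weights are adjusted
```

A real training run repeats this over every image in the dataset, usually averaging the updates over small batches, for many epochs.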

So in keeping with my previous project, I implemented a basic backpropagation algorithm in C for training on the popular MNIST dataset. I used a combination of the GNU Scientific Library and OpenBLAS for all the heavy number crunching. For the network itself I went with 2 hidden layers (4 total, including input and output layers) of 100 and 30 neurons. Below is the result after training on 50,000 images:
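To get a feel for the size of that network, the short Python sketch below counts its trainable parameters. It assumes the usual MNIST setup of 28x28 = 784 input pixels and 10 output neurons, one per digit; those two numbers are not stated above, and count_parameters is a hypothetical helper.

```python
def count_parameters(sizes):
    """Weights plus biases for a fully connected network."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))

# Assumed input and output sizes; hidden sizes are the ones given above.
layer_sizes = [784, 100, 30, 10]

print(count_parameters(layer_sizes))  # 81840
```

Almost all of those parameters sit in the first layer, which connects every pixel to every one of the 100 hidden neurons.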


The green shows accuracy on the training data and the blue shows the neural network's accuracy on a separate set of testing data. The x-axis shows the number of epochs, or the number of times the backpropagation went through all the training data and updated the network. Interestingly, after about 100 epochs the accuracy on the test data starts to decrease slightly. This is a sign that the network was overfitting to the training data. However, after around 180 epochs there is some disruption which ended up increasing the accuracy on both the training and testing data sets. Overall, the accuracy was 99.74% and 97.14% on the training and testing data respectively.

As a final test, I got my lovely wife to draw any number on the computer (she chose '4'). I then fed this into the neural network to see if it could identify what she wrote:


Clearly there is something to these neural networks after all.

Thank you for reading. Please check out Michael's book if you want to know more about neural networks. Also check out the code I wrote for this network on my GitHub.

